package com.virgo.limit.aspectj;

import com.virgo.common.exception.IApplicationException;
import com.virgo.common.utils.SpringUtils;
import com.virgo.limit.annotation.RateLimiter;
import com.virgo.limit.enums.LimitType;
import com.virgo.web.utils.IRequestUtils;
import jakarta.servlet.http.HttpServletRequest;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.redisson.api.RRateLimiter;
import org.redisson.api.RateType;
import org.redisson.api.RedissonClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.expression.BeanFactoryResolver;
import org.springframework.context.expression.MethodBasedEvaluationContext;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.ParserContext;
import org.springframework.expression.common.TemplateParserContext;
import org.springframework.expression.spel.standard.SpelExpressionParser;

import java.lang.reflect.Method;
import java.time.Duration;

/**
 * Rate limiting aspect.
 *
 * <p>Intercepts every method annotated with {@link RateLimiter} and, before the method body runs,
 * tries to take one permit from a Redisson {@link RRateLimiter}. When no permit is available an
 * {@link IApplicationException} carrying {@link RateLimiter#message()} is thrown, so the annotated
 * method is never invoked.
 *
 * @author Lion Li
 */
@Aspect
public class RateLimiterAspect {

    /**
     * SpEL expression parser used to evaluate dynamic {@link RateLimiter#key()} values.
     */
    private final ExpressionParser parser = new SpelExpressionParser();
    /**
     * Template parser context; keys wrapped in its prefix/suffix are parsed as SpEL templates.
     */
    private final ParserContext parserContext = new TemplateParserContext();
    /**
     * Resolves method parameter names so SpEL expressions can reference arguments by name.
     */
    private final ParameterNameDiscoverer pnd = new DefaultParameterNameDiscoverer();

    @Autowired
    private RedissonClient redissonClient;

    /**
     * Consumes one permit for the intercepted call and rejects the call when none is left.
     *
     * @param point       the intercepted join point
     * @param rateLimiter the {@link RateLimiter} annotation present on the intercepted method
     * @throws IApplicationException when the rate limit for the computed key is exceeded
     * @throws RuntimeException      wrapping any other failure (key evaluation, Redis access, ...)
     */
    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) {
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        try {
            String combineKey = getCombineKey(rateLimiter, point);
            // NOTE(review): Redisson's PER_CLIENT rate type counts per Redisson client instance,
            // while the User-Agent is already part of the key; confirm this combination is intended.
            RateType rateType = rateLimiter.limitType() == LimitType.USERAGENT
                ? RateType.PER_CLIENT
                : RateType.OVERALL;
            RRateLimiter limiter = redissonClient.getRateLimiter(combineKey);
            // trySetRate only stores the configuration if none exists yet for this key.
            limiter.trySetRate(rateType, count, Duration.ofSeconds(time));
            // tryAcquire alone decides the outcome. The former follow-up availablePermits() call
            // cost an extra Redis round trip and its result was only used by a disabled log line.
            if (!limiter.tryAcquire()) {
                throw new IApplicationException(rateLimiter.message());
            }
        } catch (IApplicationException e) {
            // Rate-limit rejection: propagate unchanged so callers see the configured message.
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("服务器限流异常，请稍候再试", e);
        }
    }

    /**
     * Builds the Redis key that identifies the rate limiter bucket for this call.
     *
     * <p>Format: {@code rate_limit:<requestURI>:[<ip>:|<User-Agent>:]<key>}, where the middle
     * segment depends on {@link RateLimiter#limitType()}. If {@link RateLimiter#key()} contains
     * {@code #}, it is evaluated as SpEL against the intercepted method's arguments, with bean
     * references resolved through the Spring bean factory. Keys wrapped in the template
     * prefix/suffix ({@code #{...}}) are parsed as templates; other keys are parsed as plain
     * expressions. A {@code null} SpEL result or a missing User-Agent header is appended as the
     * literal text {@code "null"}.
     *
     * @param rateLimiter the annotation supplying the key and limit type
     * @param point       the intercepted join point (source of method and argument values)
     * @return the combined rate limiter key
     */
    private String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        String key = rateLimiter.key();
        // Only keys containing '#' are treated as SpEL; anything else is used verbatim.
        if (key.contains("#")) {
            MethodSignature signature = (MethodSignature) point.getSignature();
            Method targetMethod = signature.getMethod();
            Object[] args = point.getArgs();
            MethodBasedEvaluationContext context =
                new MethodBasedEvaluationContext(null, targetMethod, args, pnd);
            context.setBeanResolver(new BeanFactoryResolver(SpringUtils.getBeanFactory()));
            boolean isTemplate = key.startsWith(parserContext.getExpressionPrefix())
                && key.endsWith(parserContext.getExpressionSuffix());
            Expression expression = isTemplate
                ? parser.parseExpression(key, parserContext)
                : parser.parseExpression(key);
            key = expression.getValue(context, String.class);
        }
        StringBuilder keyBuilder = new StringBuilder("rate_limit:");
        // NOTE(review): assumes this runs inside an HTTP request; confirm how
        // IRequestUtils.getHttpServletRequest() behaves outside one.
        HttpServletRequest request = IRequestUtils.getHttpServletRequest();
        keyBuilder.append(request.getRequestURI()).append(":");
        if (rateLimiter.limitType() == LimitType.IP) {
            // Per client IP address.
            keyBuilder.append(IRequestUtils.getIpAddr(request)).append(":");
        } else if (rateLimiter.limitType() == LimitType.USERAGENT) {
            // Per User-Agent header value.
            keyBuilder.append(request.getHeader("User-Agent")).append(":");
        }
        return keyBuilder.append(key).toString();
    }
}
